{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# SparseSinkhorn Solver"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "from lib.header_notebook import *\n",
    "import Solvers.Sinkhorn as Sinkhorn\n",
    "import lib.header_params_Sinkhorn\n",
    "%matplotlib inline"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# verbosity: disable iteration output, may become really slow in notebooks\n",
    "paramsVerbose={\n",
    "        \"solve_overview\":True,\\\n",
    "        \"solve_update\":True,\\\n",
    "        \"solve_kernel\":True,\\\n",
    "        \"solve_iterate\":False\\\n",
    "        }"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Parameters"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The full enhanced Sinkhorn algorithm requires a lot of configuration parameters:\n",
    "\n",
    "This includes modelling aspects, such was the transport-model, e.g. whether to compute standard optimal transport, Wasserstein-Fisher-Rao (including the balancing parameter between transport and growth), gradient flows or barycenters.\n",
    "\n",
    "But also computational aspects: which data structure for the kernel (dense or truncated; the truncation threshold), parameters for the log-stabilization scheme, for epsilon-scaling, for the multi-scale scheme, error tolerances etc.\n",
    "\n",
    "These parameters are stored in config files in the subdirectory cfg/ and read by this script. Several test problems have been prepared. By setting params[\"setup_tag\"] one can choose an example (see below) and subsequently run the rest of the script for solving it."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "# setup parameter managment\n",
    "\n",
    "params=lib.header_params_Sinkhorn.getParamsDefaultTransport()\n",
    "paramsListCommandLine,paramsListCFGFile=lib.header_params_Sinkhorn.getParamListsTransport()\n",
    "\n",
    "# choose setup_tag. This specifies, from which config file the problem parameters are loaded.\n",
    "# Several examples have been prepared. Uncomment the one you would like to try.\n",
    "\n",
    "\n",
    "############################################################################\n",
    "# Compare successive enhancements of simple algorithm on 64x64 test image\n",
    "# One example of data for Figure 2, at eps=0.1 h^2\n",
    "############################################################################\n",
    "## 1: log-domain stabilization: careful: takes a frustratingly long time\n",
    "#params[\"setup_tag\"]=\"cfg/Sinkhorn/CompareEnhancements/1\"\n",
    "## 2: log-domain stabilization, epsilon scaling\n",
    "#params[\"setup_tag\"]=\"cfg/Sinkhorn/CompareEnhancements/2\"\n",
    "## 3: log-domain stabilization, epsilon scaling, truncated kernel\n",
    "#params[\"setup_tag\"]=\"cfg/Sinkhorn/CompareEnhancements/3\"\n",
    "## 4: log-domain stabilization, epsilon scaling, truncated kernel, coarse-to-fine\n",
    "#params[\"setup_tag\"]=\"cfg/Sinkhorn/CompareEnhancements/4\"\n",
    "\n",
    "\n",
    "############################################################################\n",
    "# Large example, on two 256x256 images\n",
    "# One example of data for Figure 3, at eps=0.1 h^2, error=1E-6\n",
    "############################################################################\n",
    "# Standard Wasserstein distance example, as used for Fig 3\n",
    "params[\"setup_tag\"]=\"cfg/Sinkhorn/ImageBenchmark/OT_256\"\n",
    "# For comparison, a similar example with Wasserstein-Fisher-Rao distance is also given\n",
    "#params[\"setup_tag\"]=\"cfg/Sinkhorn/ImageBenchmark/WF_256\"\n",
    "\n",
    "\n",
    "############################################################################\n",
    "# Compare Wasserstein with Wasserstein-Fisher-Rao distance\n",
    "# An example with moving Gaussian blobs of different mass.\n",
    "# Run both and compare the resulting displacement interpolations\n",
    "# Standard Optimal transport will have to transfer mass between Gaussians of different weight.\n",
    "# Wasserstein-Fisher-Rao can compensate the difference by local growth / annihilation.\n",
    "############################################################################\n",
    "#params[\"setup_tag\"]=\"cfg/Sinkhorn/Gaussians/OT_128_000-001\"\n",
    "#params[\"setup_tag\"]=\"cfg/Sinkhorn/Gaussians/WF_128_000-001\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "params[\"setup_cfgfile\"]=params[\"setup_tag\"]+\".txt\"\n",
    "\n",
    "# load parameters from config file\n",
    "params.update(ScriptTools.readParameters(params[\"setup_cfgfile\"],paramsListCFGFile))\n",
    "\n",
    "# interpreting some parameters\n",
    "\n",
    "# totalMass regulates, whether marginals should be normalized or not.\n",
    "if params[\"setup_totalMass\"]<0:\n",
    "    params[\"setup_totalMass\"]=None\n",
    "# finest level for multi-scale algorithm\n",
    "params[\"hierarchy_lBottom\"]=params[\"hierarchy_depth\"]+1\n",
    "\n",
    "\n",
    "print(\"final parameter settings\")\n",
    "for k in sorted(params.keys()):\n",
    "    print(\"\\t\",k,params[k])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Problem Setup"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "# define problem: setup marginals\n",
    "\n",
    "def loadProblem(filename):\n",
    "    img=sciio.loadmat(filename)[\"a\"]\n",
    "    return img\n",
    "\n",
    "def setupDensity(img,posScale,totalMass,constOffset,keepZero):\n",
    "    (mu,pos)=OTTools.processDensity_Grid(img,totalMass=totalMass,constOffset=constOffset,keepZero=keepZero)\n",
    "    pos=pos/posScale\n",
    "    return (mu,pos,img.shape)\n",
    "\n",
    "problemData=[setupDensity(loadProblem(filename),posScale=params[\"setup_posScale\"],\\\n",
    "        totalMass=params[\"setup_totalMass\"],constOffset=params[\"setup_constOffset\"],keepZero=False)\\\n",
    "        for filename in [params[\"setup_f1\"],params[\"setup_f2\"]]]\n",
    "\n",
    "\n",
    "nProblems=len(problemData)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "# visualize marginals\n",
    "fig=plt.figure(figsize=(4*nProblems,4))\n",
    "for i in range(nProblems):\n",
    "    img=OTTools.ProjectInterpolation2D(problemData[i][1],problemData[i][0],problemData[i][2][0],problemData[i][2][1])\n",
    "    img=img.toarray()\n",
    "    fig.add_subplot(1,nProblems,i+1)\n",
    "    plt.imshow(img)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "# set up eps-scaling\n",
    "\n",
    "# geometric scaling from eps_start to eps_target in eps_steps+1 steps\n",
    "params.update(\n",
    "        Sinkhorn.Aux.SetupEpsScaling_Geometric(params[\"eps_target\"],params[\"eps_start\"],params[\"eps_steps\"],\\\n",
    "        verbose=True))\n",
    "\n",
    "# determine finest epsilon for each hierarchy level.\n",
    "# on coarsest level it is given by params[\"eps_boxScale\"]**params[\"eps_boxScale_power\"]\n",
    "# with each level, the finest scale params[\"eps_boxScale\"] is effectively divided by 2\n",
    "params[\"eps_scales\"]=[(params[\"eps_boxScale\"]/(2**n))**params[\"eps_boxScale_power\"]\\\n",
    "        for n in range(params[\"hierarchy_depth\"]+1)]+[0]\n",
    "print(\"eps_scales:\\t\",params[\"eps_scales\"])\n",
    "\n",
    "# divide eps_list into eps_lists, one for each hierarchy scale, divisions determined by eps_scales.\n",
    "params.update(Sinkhorn.Aux.SetupEpsScaling_Scales(params[\"eps_list\"],params[\"eps_scales\"],\\\n",
    "        levelTop=params[\"hierarchy_lTop\"], nOverlap=1))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "## setup hierarchical partitions\n",
    "partitionChildMode=HierarchicalPartition.THPMode_Tree\n",
    "\n",
    "# constructing basic partitions\n",
    "partitionList=[]\n",
    "for i in range(nProblems):\n",
    "    partition=HierarchicalPartition.GetPartition(problemData[i][1],params[\"hierarchy_depth\"],partitionChildMode,\\\n",
    "            box=None, signal_pos=True, signal_radii=True,clib=SolverCFC, export=False, verbose=False,\\\n",
    "            finestDimsWarning=False)\n",
    "    partitionList.append(partition)\n",
    "\n",
    "# exporting partitions\n",
    "pointerListPartition=np.zeros((nProblems),dtype=np.int64)\n",
    "for i in range(nProblems):\n",
    "    pointerListPartition[i]=SolverCFC.Export(partitionList[i])\n",
    "\n",
    "muHList=[SolverCFC.GetSignalMass(pointer,partition,aprob[0])\n",
    "        for pointer,partition,aprob in zip(pointerListPartition,partitionList,problemData)]\n",
    "\n",
    "# pointer lists\n",
    "pointerPosList=[HierarchicalPartition.getSignalPointer(partition,\"pos\") for partition in partitionList]\n",
    "pointerRadiiList=[HierarchicalPartition.getSignalPointer(partition,\"radii\",lBottom=partition.nlayers-2)\n",
    "        for partition in partitionList]\n",
    "pointerListPos=np.array([pointerPos.ctypes.data for pointerPos in pointerPosList],dtype=np.int64)\n",
    "pointerListRadii=np.array([pointerRadii.ctypes.data for pointerRadii in pointerRadiiList],dtype=np.int64)\n",
    "\n",
    "# print a few stats on the created problem\n",
    "for i,partition in enumerate(partitionList):\n",
    "    print(\"cells in partition {:d}: \".format(i),partition.cardLayers)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Solver Setup"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The code for the algorithm is extremely \"modularized\". Almost every part can be controlled by supplying suitable methods as parameters. This encompasses both modelling aspects such as different transport-type problems, as well as computational aspects, such as data structures for handling the kernels or cost functions.\n",
    "This requires a somewhat lengthy configuration sequence, and does not yield the fastest performance, but is very flexible and is thus useful for development."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "# model specific stuff\n",
    "import Solvers.Sinkhorn.Models.OT as ModelOT\n",
    "\n",
    "if params[\"model_transportModel\"]==\"ot\":\n",
    "\n",
    "    method_CostFunctionProvider = lambda level, pointerAlpha, alphaFinest=None :\\\n",
    "            Sinkhorn.CInterface.Setup_CostFunctionProvider_SquaredEuclidean(pointerListPos,\\\n",
    "                    partitionList[0].ndim,level,pointerListRadii,pointerAlpha,alphaFinest)\n",
    "\n",
    "    method_iterate_iterate = lambda kernel, alphaList, scalingList, muList, pointerListScaling, pointerListMu,\\\n",
    "            eps, nInnerIterations: \\\n",
    "            ModelOT.Iterate(kernel[0],kernel[1],scalingList[0],scalingList[1],muList[0],muList[1],nInnerIterations)\n",
    "\n",
    "    def method_iterate_error(kernel, alphaList, scalingList, muList, pointerListScaling, pointerListMu, eps):\n",
    "            return ModelOT.ErrorMarginLInf(kernel[0],kernel[1],scalingList[0],scalingList[1],muList[0],muList[1])\n",
    "\n",
    "elif params[\"model_transportModel\"]==\"wf\":\n",
    "\n",
    "    import Solvers.Sinkhorn.Models.WF as ModelWF\n",
    "\n",
    "    method_CostFunctionProvider = lambda level, pointerAlpha, alphaFinest=None :\\\n",
    "            Sinkhorn.CInterface.Setup_CostFunctionProvider_SquaredEuclideanWF(pointerListPos,\\\n",
    "                    partitionList[0].ndim,level,pointerListRadii,pointerAlpha,alphaFinest,\\\n",
    "                    FR_kappa=params[\"model_FR_kappa\"],FR_cMax=params[\"model_FR_cMax\"])\n",
    "\n",
    "    method_iterate_iterate = lambda kernel, alphaList, scalingList, muList, pointerListScaling, pointerListMu,\\\n",
    "            eps, nInnerIterations: \\\n",
    "            ModelWF.Iterate(kernel[0],kernel[1],alphaList[0],alphaList[1],scalingList[0],scalingList[1],\\\n",
    "                    muList[0],muList[1],eps,params[\"model_FR_kappa\"],nInnerIterations)\n",
    "\n",
    "    method_iterate_error = lambda kernel, alphaList, scalingList, muList, pointerListScaling, pointerListMu, eps: \\\n",
    "            ModelWF.ScorePDGap(kernel[0],kernel[1],alphaList[0],alphaList[1],\\\n",
    "                    scalingList[0],scalingList[1],muList[0],muList[1],\\\n",
    "                    eps,params[\"model_FR_kappa\"])\n",
    "\n",
    "else:\n",
    "    raise ValueError(\"model_transportModel not recognized: \"+params[\"model_transportModel\"])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# data structure choice for kernel\n",
    "if params[\"setup_type_kernel\"]==\"csr\":\n",
    "    get_method_getKernel = lambda level, muList:\\\n",
    "            lambda kernel, alpha, eps:\\\n",
    "                    Sinkhorn.GetKernel_SparseCSR(\n",
    "                            partitionList,pointerListPartition,\\\n",
    "                            method_CostFunctionProvider,\\\n",
    "                            level, alpha, eps,\\\n",
    "                            kThresh=params[\"sparsity_kThresh\"],\\\n",
    "                            baseMeasureX=muList[0], baseMeasureY=muList[1],\\\n",
    "                            sanityCheck=False,\\\n",
    "                            verbose=paramsVerbose[\"solve_kernel\"])\n",
    "\n",
    "    method_deleteKernel = lambda kernel : None\n",
    "\n",
    "    method_refineKernel = lambda level, kernel, alphaList, muList, eps:\\\n",
    "                    Sinkhorn.RefineKernel_CSR(partitionList, pointerListPartition,\\\n",
    "                            method_CostFunctionProvider,\\\n",
    "                            level, (kernel[0].indices,kernel[0].indptr), alphaList,\\\n",
    "                            eps,\\\n",
    "                            baseMeasureX=muList[0], baseMeasureY=muList[1],\\\n",
    "                            verbose=paramsVerbose[\"solve_kernel\"])\n",
    "\n",
    "    method_getKernelVariablesCount=Sinkhorn.GetKernelVariablesCount_CSR\n",
    "\n",
    "elif params[\"setup_type_kernel\"]==\"dense\":\n",
    "    get_method_getKernel = lambda level, muList:\\\n",
    "            lambda kernel, alpha, eps:\\\n",
    "                    Sinkhorn.GetKernel_DenseArray(\n",
    "                            partitionList,pointerListPartition,\\\n",
    "                            method_CostFunctionProvider,\\\n",
    "                            level, alpha, eps,\\\n",
    "                            baseMeasureX=muList[0], baseMeasureY=muList[1],\\\n",
    "                            truncationThresh=1E-200,verbose=paramsVerbose[\"solve_kernel\"]\\\n",
    "                            )\n",
    "\n",
    "\n",
    "    method_deleteKernel = lambda kernel : None\n",
    "\n",
    "    method_refineKernel = lambda level, kernel, alphaList, muList, eps:\\\n",
    "            None\n",
    "    \n",
    "    method_getKernelVariablesCount=Sinkhorn.GetKernelVariablesCount_Array\n",
    "\n",
    "else:\n",
    "    raise ValueError(\"setup_type_kernel not recognized: \"+params[\"setup_type_kernel\"])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "if params[\"setup_doAbsorption\"]==1:\n",
    "    method_absorbScaling = lambda alphaList,scalingList,eps:\\\n",
    "                    Sinkhorn.Method_AbsorbScalings(alphaList,scalingList,eps,\\\n",
    "                            residualScaling=None,minAlpha=None,verbose=False)\n",
    "else:\n",
    "    method_absorbScaling = lambda alphaList,scalingList,eps: None\n",
    "\n",
    "get_method_update = lambda epsList, method_getKernel, method_deleteKernel, method_absorbScaling:\\\n",
    "        lambda status, data:\\\n",
    "                Sinkhorn.Update(status,data,epsList,\\\n",
    "                        method_getKernel, method_deleteKernel, method_absorbScaling,\\\n",
    "                        absorbFinalIteration=True,maxRepeats=params[\"sinkhorn_maxRepeats\"],\\\n",
    "                        verbose=paramsVerbose[\"solve_update\"]\\\n",
    "                        )\n",
    "\n",
    "\n",
    "method_iterate = lambda status,data : Sinkhorn.Method_IterateToPrecision(status,data,\\\n",
    "                method_iterate=method_iterate_iterate,method_error=method_iterate_error,\\\n",
    "                maxError=params[\"sinkhorn_error\"],\\\n",
    "                nInnerIterations=params[\"sinkhorn_nInner\"],maxOuterIterations=params[\"sinkhorn_maxOuter\"],\\\n",
    "                scalingBound=params[\"adaption_scalingBound\"],scalingLowerBound=params[\"adaption_scalingLowerBound\"],\\\n",
    "                verbose=paramsVerbose[\"solve_iterate\"])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false,
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "result=Sinkhorn.MultiscaleSolver(partitionList,pointerListPartition,muHList,params[\"eps_lists\"],\\\n",
    "        get_method_getKernel,method_deleteKernel,method_absorbScaling,\\\n",
    "        method_iterate,get_method_update,method_refineKernel,\\\n",
    "        params[\"hierarchy_lTop\"],params[\"hierarchy_lBottom\"],\\\n",
    "        collectReports=True,method_getKernelVariablesCount=method_getKernelVariablesCount,\\\n",
    "        verbose=paramsVerbose[\"solve_overview\"],\\\n",
    "        )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# extract results of algorithm\n",
    "data=result[\"data\"]\n",
    "status=result[\"status\"]\n",
    "setup=result[\"setup\"]\n",
    "setupAux=result[\"setupAux\"]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "# re-estimate kernel one last time\n",
    "data[\"kernel\"]=setupAux[\"method_getKernel\"](data[\"kernel\"],data[\"alpha\"],data[\"eps\"])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# clean up hieararchical partitions\n",
    "for pointer in pointerListPartition:\n",
    "    SolverCFC.Close(pointer)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Post-Processing"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Some post-processing for fun."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Extract Coupling"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Extract coupling from solution, test marginal accuracy."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "import Solvers.Sinkhorn.Models.Common as ModelCommon"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "pi=ModelCommon.GetCouplingCSR(data[\"kernel\"][0],data[\"scaling\"][0],data[\"scaling\"][1])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "if params[\"model_transportModel\"]==\"ot\":\n",
    "    # L^1 marginal errors\n",
    "    m0,m1=ModelCommon.GetMarginals(pi)\n",
    "    print([np.sum(np.abs(marg-probData[0])) for marg,probData in zip((m0,m1),problemData)])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Displacement Interpolation"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "In this section we compute a naive displacement interpolation from the optimal couplings. For classical optimal transport it is well known how to do this (in the continuum), for Wasserstein-Fisher-Rao / Hellinger-Kantorovich we refer to [Liero, Mielke, Savaré: 'Optimal Entropy-Transport problems and a new Hellinger-Kantorovich distance between positive measures'], see also [Chizat, Peyré, Schmitzer, Vialard: 'An Interpolating Distance between Optimal Transport and Fisher-Rao Metrics' and 'Unbalanced Optimal Transport: Geometry and Kantorovich Formulation'].\n",
    "\n",
    "Note that handling of the discretization is done in a very simplistic way: for each discrete mass particle, the optimal continuous trajectory over the image plane is computed. Each travelling particle is then projected onto the nearest pixels, its mass being distributed according to piecewise linear interpolation. This leads to oscillation-type artifacts when mass distributions are deformed smoothly. These artifacts are purely a result of the interpolation process.\n",
    "\n",
    "Nevertheless, these interpolations help to visualize the transport process."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "if params[\"model_transportModel\"]==\"ot\":\n",
    "    # extract data about travelling mass particles from coupling\n",
    "    particles=ModelOT.GetParticles(pi)\n",
    "    print(particles[0].shape)\n",
    "    method_interpolate=ModelOT.interpolateEuclidean\n",
    "elif params[\"model_transportModel\"]==\"wf\":\n",
    "    # extract data about travelling mass particles from coupling\n",
    "    particles=ModelWF.GetParticles(pi,problemData[0][0],problemData[1][0],1E-7)\n",
    "    print([x[0].shape for x in particles])\n",
    "    def method_interpolate(particles,pos0,pos1,t):\n",
    "            rhoPre=[ModelWF.interpolateEuclidean(p,posx,posy,t,params[\"model_FR_kappa\"])\\\n",
    "                    for p,posx,posy in zip(particles,[pos0,pos0,pos1],[pos1,pos0,pos1])]\n",
    "            rhoPos=np.vstack([x[0] for x in rhoPre])\n",
    "            rhoMass=np.hstack([x[1] for x in rhoPre])\n",
    "            return (rhoPos,rhoMass)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "datT=np.linspace(0,1,10)\n",
    "imgList=[]\n",
    "res=problemData[0][2]\n",
    "for i,t in enumerate(datT):\n",
    "\n",
    "    rho=method_interpolate(particles,problemData[0][1],problemData[1][1],t)\n",
    "\n",
    "    projection=OTTools.ProjectInterpolation2D(rho[0],rho[1],res[0],res[1]).toarray()\n",
    "    imgList.append(projection)\n",
    "    \n",
    "    plt.imshow(projection)\n",
    "    plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "massList=[np.sum(img) for img in imgList]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "plt.plot(massList)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.4.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}
